#### FUNCTION TO STANDARDIZE CAUSAL FOREST ESTIMATION FOR CONJOINT COMPARISONS ####

# Fit a causal forest for a single conjoint attribute comparison.
#
# Restricts `df` to rows whose conjoint treatment is one of
# `treatment_comparison`, then fits grf::causal_forest() using the selected
# covariates (X), outcome (Y) and binary treatment indicator (W).
#
# Arguments:
#   df                   data frame holding outcome, treatment and covariates.
#   outcome              unquoted outcome column (embraced with {{ }}).
#   treatment_cjoint     unquoted conjoint-treatment column used to subset rows.
#   treatment_binary     unquoted treatment-indicator column passed as W.
#   treatment_comparison vector of `treatment_cjoint` values to keep.
#   covars               character vector of covariate column names used as X;
#                        every name must exist in `df` (all_of() errors
#                        otherwise, instead of silently dropping it).
#   seed, num.trees      passed to grf::causal_forest().
#   tune.num.trees, tune.num.reps
#                        passed to grf::causal_forest(). NOTE: grf only uses
#                        these when tuning is requested through
#                        `tune.parameters`, whose grf default is "none"; supply
#                        e.g. `tune.parameters = "all"` via `...` to enable it.
#   ...                  further arguments forwarded to grf::causal_forest()
#                        (e.g. tune.parameters, clusters, sample.weights).
#
# Returns the fitted causal forest object from grf::causal_forest().
run_causal_forest_cjoint <- function(df, outcome, treatment_cjoint,
                                     treatment_binary,
                                     treatment_comparison, covars,
                                     seed = 10009, num.trees = 6000,
                                     tune.num.trees = 1500,
                                     tune.num.reps = 500, ...) {

  # Keep only the conjoint levels being compared. Namespaced so that
  # stats::filter() is never used when dplyr is not attached.
  df <- dplyr::filter(df, {{ treatment_cjoint }} %in% treatment_comparison)

  # Covariates (X): strict selection, a misspelled name is an error.
  df_X <- dplyr::select(df, dplyr::all_of(covars))

  # Outcome (Y) and treatment indicator (W) as plain vectors.
  df_y <- dplyr::pull(df, {{ outcome }})
  df_w <- dplyr::pull(df, {{ treatment_binary }})

  grf::causal_forest(
    X = df_X,
    Y = df_y,
    W = df_w,
    seed = seed,
    num.trees = num.trees,
    tune.num.trees = tune.num.trees,
    tune.num.reps = tune.num.reps,
    ...
  )
}

#### FUNCTION TO TIDY RESULTS ####

# Turn a fitted grf causal forest into a tidy, one-row-per-observation tibble.
#
# Columns produced:
#   tau.hat            out-of-bag predicted treatment effect per observation.
#   se                 its standard error (sqrt of grf's variance estimate).
#   id                 rank of tau.hat (ties receive averaged ranks).
#   cate, cate_se      sample-wide average treatment effect and its standard
#                      error from average_treatment_effect(target.sample =
#                      "all"), repeated on every row (names kept for
#                      backward compatibility).
#   <covariates>       the training covariates stored in cf$X.orig.
#   lower/upper.90/.95 pointwise normal confidence bounds for tau.hat.
#   *_char             text labels built from the indicator covariates
#                      is_democrat, is_republican, is_independent, is_union,
#                      is_blue_collar, is_white_collar and is_college, which
#                      must therefore be present in cf$X.orig.
#
# Side effect: attaches the grf and dplyr packages (library() errors
# immediately if either is not installed, unlike require()).
tidy_causal_forest <- function(cf) {
  library(grf)
  library(dplyr)

  # Out-of-bag predictions with variance estimates for confidence intervals.
  predictions <- predict(cf, estimate.variance = TRUE)
  se <- sqrt(predictions$variance.estimates)

  # Average treatment effect over the full sample: grf returns a length-2
  # vector (estimate, standard error). [[ ]] drops the element names.
  ate <- average_treatment_effect(cf, target.sample = "all")

  # Exact two-sided normal critical values (1.6449, 1.9600); the previously
  # hard-coded 1.64 made the 90% interval slightly narrower than nominal.
  z_90 <- qnorm(0.95)
  z_95 <- qnorm(0.975)

  tibble(
    tau.hat = predictions$predictions,
    se = se,
    id = rank(tau.hat),
    cate = ate[[1]],
    cate_se = ate[[2]]
  ) %>%
    # as.data.frame() lets forests fit with a matrix X bind as well.
    bind_cols(as.data.frame(cf$X.orig)) %>%
    mutate(
      lower.90 = tau.hat - z_90 * se,
      upper.90 = tau.hat + z_90 * se,
      lower.95 = tau.hat - z_95 * se,
      upper.95 = tau.hat + z_95 * se
    ) %>%
    mutate(
      party_char = case_when(
        is_democrat == 1 ~ "Democrat",
        is_republican == 1 ~ "Republican",
        is_independent == 1 ~ "Independent"
      ),
      union_char = case_when(
        is_union == 1 ~ "Unionized",
        is_union == 0 ~ "Not Unionized"
      ),
      industry_char = case_when(
        is_blue_collar == 1 ~ "Blue-collar",
        is_white_collar == 1 ~ "White-collar"
      ),
      education_char = case_when(
        is_college == 1 ~ "College-educated",
        is_college == 0 ~ "Not college-educated"
      )
    )
}